{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i3puWgvKeyWu"
      },
      "source": [
        "# Auto-Batched Joint Distributions: A Gentle Tutorial"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZrwVQsM9TiUw"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "CpDUTVKYTowI"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ltPJCG6pAUoc"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/probability/examples/Modeling_with_JointDistribution\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/probability/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zzaOJSXagzMY"
      },
      "source": [
        "### Introduction"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cIvB2CSBe49Z"
      },
      "source": [
        "TensorFlow Probability (TFP) offers a number of `JointDistribution` abstractions that make probabilistic inference easier by allowing a user to express a probabilistic graphical model in a near-mathematical form; the abstraction generates methods for sampling from the model and evaluating the log probability of samples from the model. In this tutorial, we review \"autobatched\" variants, which were developed after the original `JointDistribution` abstractions. Relative to the original, non-autobatched abstractions, the autobatched versions are simpler to use and more ergonomic, allowing many models to be expressed with less boilerplate. In this colab, we explore a simple model in (perhaps tedious) detail, making clear the problems autobatching solves, and (hopefully) teaching the reader more about TFP shape concepts along the way.\n",
        "\n",
        "Prior to the introduction of autobatching, there were a few different variants of `JointDistribution`, corresponding to different syntactic styles for expressing probabilistic models: `JointDistributionSequential`, `JointDistributionNamed`, and `JointDistributionCoroutine`. Autobatching exists as a mixin, so we now have `AutoBatched` variants of all of these. In this tutorial, we explore the differences between `JointDistributionSequential` and `JointDistributionSequentialAutoBatched`; however, everything we do here is applicable to the other variants with essentially no changes.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uiR4-VOt9NFX"
      },
      "source": [
        "### Dependencies & Prerequisites\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "coUnDhkpT5_6"
      },
      "outputs": [],
      "source": [
        "#@title Imports and setup { display-mode: \"form\" }\n",
        "\n",
        "import functools\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "tf.enable_v2_behavior()\n",
        "\n",
        "import tensorflow_probability as tfp\n",
        "\n",
        "tfd = tfp.distributions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KohBmaTn5W7I"
      },
      "source": [
        "### Prerequisite: A Bayesian Regression Problem"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vChyK0vr9XD8"
      },
      "source": [
        "We'll consider a very simple Bayesian regression scenario:\n",
        "\\begin{align*}\n",
        "m & \\sim \\text{Normal}(0, 1) \\\\\n",
        "b & \\sim \\text{Normal}(0, 1) \\\\\n",
        "Y & \\sim \\text{Normal}(mX + b, 1)\n",
        "\\end{align*}\n",
        "\n",
        "In this model, `m` and `b` are drawn from standard normals, and the observations `Y` are drawn from a normal distribution whose mean depends on the random variables `m` and `b`, and some (nonrandom, known) covariates `X`. (For simplicity, in this example, we assume the scale of all random variables is known.)\n",
        "\n",
        "To perform inference in this model, we'd need to know both the covariates `X` and the observations `Y`, but for the purposes of this tutorial, we'll only need `X`, so we define a simple dummy `X`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "UIpJ_cXUVabB"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([0, 1, 2, 3, 4, 5, 6])"
            ]
          },
          "execution_count": 3,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "X = np.arange(7)\n",
        "X"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CIBpupyt9GTT"
      },
      "source": [
        "### Desiderata"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j2uzL_uI9tqO"
      },
      "source": [
        "In probabilistic inference, we often want to perform two basic operations:\n",
        "- `sample`: Drawing samples from the model.\n",
        "- `log_prob`: Computing the log probability of a sample from the model.\n",
        "\n",
        "The key contribution of TFP's `JointDistribution` abstractions (as well as of many other approaches to probabilistic programming) is to allow users to write a model *once* and have access to both `sample` and `log_prob` computations.\n",
        "\n",
        "Noting that we have 7 points in our data set (`X.shape = (7,)`), we can now state the desiderata for an excellent `JointDistribution`:\n",
        "\n",
        "* `sample()` should produce a list of `Tensors` having shape `[(), (), (7,)]`, corresponding to the scalar slope, scalar bias, and vector observations, respectively.\n",
        "* `log_prob(sample())` should produce a scalar: the log probability of a particular slope, bias, and observations.\n",
        "* `sample([5, 3])` should produce a list of `Tensors` having shape `[(5, 3), (5, 3), (5, 3, 7)]`, representing a `(5, 3)`-*batch* of samples from the model.\n",
        "* `log_prob(sample([5, 3]))` should produce a `Tensor` with shape `(5, 3)`.\n",
        "\n",
        "We'll now look at a succession of `JointDistribution` models, see how to achieve the above desiderata, and hopefully learn a little more about TFP shapes along the way.\n",
        "\n",
        "Spoiler alert: The approach that satisfies the above desiderata without added boilerplate is [autobatching](#scrollTo=_h7sJ2bkfOS7)."
      ]
    },
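    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make these targets concrete, here is a NumPy-only sketch (not TFP code) of a sampler and a log-density for the regression model that meet all four desiderata. The helper names `normal_log_prob`, `sample_model`, and `log_prob_model` are hypothetical and exist only for illustration; the two moves that matter are the trailing axis added to `m` and `b`, and the sum over the 7 observations.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "X = np.arange(7)\n",
        "\n",
        "def normal_log_prob(x, loc, scale=1.0):\n",
        "    # Elementwise log density of Normal(loc, scale); nothing is summed.\n",
        "    z = (x - loc) / scale\n",
        "    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)\n",
        "\n",
        "def sample_model(sample_shape=(), seed=0):\n",
        "    # Hypothetical reference sampler returning [m, b, Y].\n",
        "    rng = np.random.default_rng(seed)\n",
        "    m = np.asarray(rng.standard_normal(sample_shape))\n",
        "    b = np.asarray(rng.standard_normal(sample_shape))\n",
        "    # A trailing length-1 axis lets sample_shape + (1,) broadcast against X's (7,).\n",
        "    loc = m[..., np.newaxis] * X + b[..., np.newaxis]\n",
        "    y = loc + rng.standard_normal(loc.shape)\n",
        "    return [m, b, y]\n",
        "\n",
        "def log_prob_model(m, b, y):\n",
        "    # Hypothetical reference log-density: Y's 7 terms are summed, so m and b count once.\n",
        "    loc = m[..., np.newaxis] * X + b[..., np.newaxis]\n",
        "    return (normal_log_prob(m, 0.0) + normal_log_prob(b, 0.0)\n",
        "            + normal_log_prob(y, loc).sum(axis=-1))\n",
        "\n",
        "print([x.shape for x in sample_model()])            # [(), (), (7,)]\n",
        "print(np.shape(log_prob_model(*sample_model())))    # ()\n",
        "print([x.shape for x in sample_model((5, 3))])      # [(5, 3), (5, 3), (5, 3, 7)]\n",
        "print(log_prob_model(*sample_model((5, 3))).shape)  # (5, 3)\n",
        "```\n",
        "\n",
        "Nothing in `log_prob_model` hard-codes the sample shape, which is exactly the behavior we are asking of a `JointDistribution`."
      ]
    },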
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QiII0ypZcyTY"
      },
      "source": [
        "### First Attempt: `JointDistributionSequential`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "kY501q-QVR9g"
      },
      "outputs": [],
      "source": [
        "jds = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hzNPPqJ-BwA-"
      },
      "source": [
        "This is more or less a direct translation of the model into code. The slope `m` and bias `b` are straightforward. `Y` is defined using a `lambda`-function: the general pattern is that a `lambda`-function of $k$ arguments in a `JointDistributionSequential` (JDS) receives the values of the previous $k$ distributions in the model, most recent first. Note the \"reverse\" order: `b` is defined after `m`, so it comes first in `lambda b, m`."
      ]
    },
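    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The argument-passing convention can be mimicked in plain Python. The sketch below uses a hypothetical helper, `call_with_previous`, that inspects a function's arity and passes it the most recent values first; it illustrates the convention described above and is not TFP's actual implementation.\n",
        "\n",
        "```python\n",
        "import inspect\n",
        "\n",
        "def call_with_previous(fn, previous_values):\n",
        "    # Hypothetical helper: count fn's parameters, then pass that many of the\n",
        "    # most recently produced values, newest first (the reverse of model order).\n",
        "    k = len(inspect.signature(fn).parameters)\n",
        "    newest_first = list(reversed(previous_values))\n",
        "    return fn(*newest_first[:k])\n",
        "\n",
        "# Values produced so far, in model order: m first, then b.\n",
        "previous = ['m_value', 'b_value']\n",
        "print(call_with_previous(lambda b, m: (b, m), previous))  # ('b_value', 'm_value')\n",
        "print(call_with_previous(lambda b: b, previous))          # b_value\n",
        "```\n",
        "\n",
        "One practical consequence: writing the lambda as `lambda m, b: ...` would not raise an error; the parameter named `m` would silently receive the value of `b`."
      ]
    },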
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5jIvsQSOD81N"
      },
      "source": [
        "We'll call `sample_distributions`, which returns both a sample *and* the underlying \"sub-distributions\" that were used to generate the sample. (We could have produced just the sample by calling `sample`; later in the tutorial it will be convenient to have the distributions as well.) The sample we produce is fine:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "y05IrsfiaxCh"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tf.Tensor: shape=(), dtype=float32, numpy=0.08079692>,\n",
              " <tf.Tensor: shape=(), dtype=float32, numpy=-1.5032883>,\n",
              " <tf.Tensor: shape=(7,), dtype=float32, numpy=\n",
              " array([-1.906176  ,  0.53724945, -0.30291188, -0.86593336, -0.00641394,\n",
              "        -0.58248115, -2.907504  ], dtype=float32)>]"
            ]
          },
          "execution_count": 5,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dists, s = jds.sample_distributions()\n",
        "s"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7E1WkoCEB12"
      },
      "source": [
        "But `log_prob` produces a result with an undesired shape:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "xR0lbgjNay4X"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(7,), dtype=float32, numpy=\n",
              "array([-3.9711766, -5.8103094, -4.429552 , -3.9680157, -4.578788 ,\n",
              "       -4.02357  , -5.674173 ], dtype=float32)>"
            ]
          },
          "execution_count": 6,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds.log_prob(s)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1mMIs28LEJqN"
      },
      "source": [
        "And drawing multiple samples doesn't work:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "LbfRiIsfc9Hf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "  jds.sample([5, 3])\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rnvtz3SQHrVL"
      },
      "source": [
        "Let's try to understand what's going wrong."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dp30JPCmHyuz"
      },
      "source": [
        "### A Brief Review: Batch and Event Shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w24fZn3kH2uF"
      },
      "source": [
        "In TFP, an ordinary (not a `JointDistribution`) probability distribution has an *event shape* and a *batch shape*, and understanding the difference is crucial to effective use of TFP:\n",
        "\n",
        "* Event shape describes the shape of a single draw from the distribution; the draw may be dependent across dimensions. For scalar distributions, the event shape is `[]`. For a 5-dimensional `MultivariateNormal`, the event shape is `[5]`.\n",
        "* Batch shape describes independent, not identically distributed draws, a.k.a. a \"batch\" of distributions. Representing a batch of distributions in a single Python object is one of the key ways TFP achieves efficiency at scale.\n",
        "\n",
        "For our purposes, a critical fact to keep in mind is that if we call `log_prob` on a single sample from a distribution, the result has exactly the *batch* shape. More generally, calling `log_prob` on a sample of shape `sample_shape` returns a result of shape `sample_shape + batch_shape`, so the batch shape always forms the rightmost dimensions.\n",
        "\n",
        "For a more in-depth discussion of shapes, see [the \"Understanding TensorFlow Distributions Shapes\" tutorial](https://www.tensorflow.org/probability/examples/Understanding_TensorFlow_Distributions_Shapes).\n"
      ]
    },
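    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is that rule as a NumPy-only sketch, handling a batch of 7 scalar normals (batch shape `[7]`, event shape `[]`) by hand. The `normal_log_prob` helper is hypothetical; like a TFP batch, it evaluates each normal separately and sums nothing.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def normal_log_prob(x, loc, scale=1.0):\n",
        "    # Elementwise log density: one value per scalar normal, no reduction.\n",
        "    z = (x - loc) / scale\n",
        "    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "loc = np.arange(7.0)  # 7 independent normals: batch shape [7], event shape []\n",
        "\n",
        "one_draw = loc + rng.standard_normal(loc.shape)             # shape (7,)\n",
        "many_draws = loc + rng.standard_normal((5, 3) + loc.shape)  # shape (5, 3, 7)\n",
        "\n",
        "lp_one = normal_log_prob(one_draw, loc)     # shape (7,): exactly the batch shape\n",
        "lp_many = normal_log_prob(many_draws, loc)  # shape (5, 3, 7): sample shape + batch shape\n",
        "print(lp_one.shape, lp_many.shape)          # (7,) (5, 3, 7)\n",
        "```"
      ]
    },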
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nONZMjl-KtTz"
      },
      "source": [
        "### Why Doesn't `log_prob(sample())` Produce a Scalar? "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VUKyGzkOJiuD"
      },
      "source": [
        "Let's use our knowledge of batch and event shape to explore what's happening with `log_prob(sample())`. Here's our sample again:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "ijRGAnSBJwCG"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tf.Tensor: shape=(), dtype=float32, numpy=0.08079692>,\n",
              " <tf.Tensor: shape=(), dtype=float32, numpy=-1.5032883>,\n",
              " <tf.Tensor: shape=(7,), dtype=float32, numpy=\n",
              " array([-1.906176  ,  0.53724945, -0.30291188, -0.86593336, -0.00641394,\n",
              "        -0.58248115, -2.907504  ], dtype=float32)>]"
            ]
          },
          "execution_count": 8,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "s"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NAzBAsu3OoLv"
      },
      "source": [
        "And here are our distributions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "_xtIUKf8Nq3G"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,\n",
              " <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,\n",
              " <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]"
            ]
          },
          "execution_count": 9,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dists"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LzkLnoZyFeU_"
      },
      "source": [
        "The joint log probability is computed by evaluating each sub-distribution's log probability at the corresponding part of the sample and summing the results:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "5XTDKVMPO5qg"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tf.Tensor: shape=(), dtype=float32, numpy=-0.9222026>,\n",
              " <tf.Tensor: shape=(), dtype=float32, numpy=-2.0488763>,\n",
              " <tf.Tensor: shape=(7,), dtype=float32, numpy=\n",
              " array([-1.0000978 , -2.8392305 , -1.4584732 , -0.99693686, -1.6077087 ,\n",
              "        -1.0524913 , -2.703094  ], dtype=float32)>]"
            ]
          },
          "execution_count": 10,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "log_prob_parts = [dist.log_prob(ss) for (dist, ss) in zip(dists, s)]\n",
        "log_prob_parts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "QoWsVGx8N1IJ"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>"
            ]
          },
          "execution_count": 11,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Python's built-in sum broadcasts the scalar parts against the 7-vector part.\n",
        "sum(log_prob_parts) - jds.log_prob(s)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJFvR4ZNFngd"
      },
      "source": [
        "So, one level of explanation is that the log probability calculation is returning a 7-Tensor because the third subcomponent of `log_prob_parts` is a 7-Tensor. But why?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zdpKnguOPOrr"
      },
      "source": [
        "Well, we see that the last element of `dists`, which corresponds to our distribution over `Y` in the mathematical formulation, has a `batch_shape` of `[7]`. In other words, our distribution over `Y` is a batch of 7 independent normals (with different means and, in this case, the same scale)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0WXzlR_diTuZ"
      },
      "source": [
        "We now understand what's wrong: in JDS, the distribution over `Y` has `batch_shape=[7]`, so a sample from the JDS consists of scalars for `m` and `b` and a \"batch\" of 7 independent normals, and `log_prob` computes 7 separate log-probabilities, each of which represents the log probability of drawing `m` and `b` and a single observation `Y[i]` at some `X[i]`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s9RI0oxCi_En"
      },
      "source": [
        "### Fixing `log_prob(sample())` with `Independent`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EOL1hllzjDcF"
      },
      "source": [
        "Recall that `dists[2]` has `event_shape=[]` and `batch_shape=[7]`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "TA05J9VwjCLu"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>"
            ]
          },
          "execution_count": 12,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dists[2]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_xQ5ORIqjPAz"
      },
      "source": [
        "By using TFP's `Independent` metadistribution, which converts batch dimensions to event dimensions, we can convert this into a distribution with `event_shape=[7]` and `batch_shape=[]` (we'll rename it `y_dist_i` because it's a distribution on `Y`, with the `_i` standing in for our `Independent` wrapping): "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "Aa_SPItTjLBO"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>"
            ]
          },
          "execution_count": 13,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)\n",
        "y_dist_i"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JrRjuDhhmBEr"
      },
      "source": [
        "Now, the `log_prob` of a 7-vector is a scalar:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "y9yZs-kwdLGa"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(), dtype=float32, numpy=-11.658031>"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "y_dist_i.log_prob(s[2])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RqNEen4Ujkhh"
      },
      "source": [
        "Under the covers, `Independent` sums over the batch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "SxYr1McJkWFx"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(), dtype=float32, numpy=0.0>"
            ]
          },
          "execution_count": 15,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "y_dist_i.log_prob(s[2]) - tf.reduce_sum(dists[2].log_prob(s[2]))"
      ]
    },
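    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reinterpretation can be sketched in NumPy: sum the rightmost `reinterpreted_batch_ndims` axes of the batch log-probabilities. The helper `independent_log_prob` is hypothetical; the input values are the seven per-observation log-probabilities printed above in `log_prob_parts[2]`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def independent_log_prob(batch_log_prob, reinterpreted_batch_ndims):\n",
        "    # Hypothetical helper: sum out the rightmost batch axes, which turns them\n",
        "    # into event axes. Leading (sample or remaining batch) axes are kept.\n",
        "    if reinterpreted_batch_ndims == 0:\n",
        "        return batch_log_prob\n",
        "    axes = tuple(range(-reinterpreted_batch_ndims, 0))\n",
        "    return np.sum(batch_log_prob, axis=axes)\n",
        "\n",
        "# Per-observation log-probabilities of Y, copied from log_prob_parts[2] above.\n",
        "y_parts = np.array([-1.0000978, -2.8392305, -1.4584732, -0.99693686,\n",
        "                    -1.6077087, -1.0524913, -2.703094])\n",
        "print(independent_log_prob(y_parts, 1))  # about -11.658, matching y_dist_i.log_prob(s[2])\n",
        "\n",
        "# With a (5, 3) sample shape in front, only the last axis is summed.\n",
        "print(independent_log_prob(np.zeros((5, 3, 7)), 1).shape)  # (5, 3)\n",
        "```"
      ]
    },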
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00lD003YkojA"
      },
      "source": [
        "And indeed, we can use this to construct a new `jds_i` (the `i` again stands for `Independent`) where `log_prob` returns a scalar:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "1jwoSeNWkhT6"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(), dtype=float32, numpy=-14.62911>"
            ]
          },
          "execution_count": 16,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_i = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Independent(   # Y\n",
        "        tfd.Normal(loc=m*X + b, scale=1.),\n",
        "        reinterpreted_batch_ndims=1)\n",
        "])\n",
        "\n",
        "jds_i.log_prob(s)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hYY3CNBXlAIZ"
      },
      "source": [
        "A couple of notes:\n",
        "- `jds_i.log_prob(s)` is *not* the same as `tf.reduce_sum(jds.log_prob(s))`. The former produces the \"correct\" log probability of the joint distribution. The latter sums over a 7-Tensor, each element of which is the sum of the log probability of `m`, `b`, and a single element of the log probability of `Y`, so it overcounts `m` and `b`. (`log_prob(m) + log_prob(b) + log_prob(Y)` returns a result rather than throwing an exception because TFP follows TF and NumPy's broadcasting rules; adding a scalar to a vector produces a vector-sized result.)\n",
        "- In this particular case, we could have solved the problem and achieved the same result using `MultivariateNormalDiag` instead of `Independent(Normal(...))`. `MultivariateNormalDiag` is a vector-valued distribution (i.e., it already has vector event shape). Indeed, `MultivariateNormalDiag` could be (but isn't) implemented as a composition of `Independent` and `Normal`. It's worth remembering that, given a vector `V`, samples from `n1 = Normal(loc=V)` and `n2 = MultivariateNormalDiag(loc=V)` are indistinguishable; the difference between these distributions is that `n1.log_prob(n1.sample())` is a vector and `n2.log_prob(n2.sample())` is a scalar."
      ]
    },
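    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The overcounting described in the first note can be written out explicitly. Let $\\ell_m = \\log p(m)$, $\\ell_b = \\log p(b)$, and $\\ell_i = \\log p(Y_i \\mid m, b)$. The non-`Independent` model broadcasts $\\ell_m + \\ell_b$ against each of the 7 terms $\\ell_i$, so summing its output gives:\n",
        "\\begin{align*}\n",
        "\\sum_{i=1}^{7} \\left(\\ell_m + \\ell_b + \\ell_i\\right) &= 7\\ell_m + 7\\ell_b + \\sum_{i=1}^{7} \\ell_i \\\\\n",
        "&= \\underbrace{\\ell_m + \\ell_b + \\sum_{i=1}^{7} \\ell_i}_{\\log p(m,\\, b,\\, Y)} + 6\\left(\\ell_m + \\ell_b\\right)\n",
        "\\end{align*}\n",
        "\n",
        "Plugging in the log-probability parts printed earlier, $\\ell_m + \\ell_b \\approx -0.9222 + (-2.0489) = -2.9711$. Summing the 7-vector returned by `jds.log_prob(s)` gives about $-32.456$, while `jds_i.log_prob(s)` is about $-14.629$; the gap, about $-17.827$, is $6 \\times (-2.9711)$."
      ]
    },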
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b-iFi65ZmvpB"
      },
      "source": [
        "### Multiple Samples?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PZcEBJS_nAhA"
      },
      "source": [
        "Drawing multiple samples still doesn't work:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "PkvYmB3jm2sI"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "  jds_i.sample([5, 3])\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b9Jh0MTCn0Mr"
      },
      "source": [
        "Let's think about why. When we call `jds_i.sample([5, 3])`, we'll first draw samples for `m` and `b`, each with shape `(5, 3)`. Next, we're going to try to construct a `Normal` distribution via:\n",
        "```\n",
        "tfd.Normal(loc=m*X + b, scale=1.)\n",
        "```\n",
        "\n",
        "But if `m` has shape `(5, 3)` and `X` has shape `(7,)`, we can't multiply them together (their trailing dimensions, 3 and 7, don't broadcast), and indeed this is the error we're hitting:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "ei9Z2Nozp8Dy"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
          ]
        }
      ],
      "source": [
        "m = tfd.Normal(0., 1.).sample([5, 3])\n",
        "try:\n",
        "  m * X\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1uqaIx2LlaeP"
      },
      "source": [
        "To resolve this issue, let's think about what properties the distribution over `Y` must have. If we've called `jds_i.sample([5, 3])`, then we know `m` and `b` will both have shape `(5, 3)`. What shape should a call to `sample` on the `Y` distribution produce? The obvious answer is `(5, 3, 7)`: for each batch point, we want a sample with the same size as `X`. We can achieve this by using TensorFlow's broadcasting capabilities, adding a trailing dimension of size 1:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "-22Bg8Yfr6tg"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "TensorShape([5, 3, 1])"
            ]
          },
          "execution_count": 19,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "m[..., tf.newaxis].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "7k21MOvlsHGe"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "TensorShape([5, 3, 7])"
            ]
          },
          "execution_count": 20,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "(m[..., tf.newaxis] * X).shape"
      ]
    },
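    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same fix in NumPy, as a sketch: a hypothetical `y_loc` helper that appends a trailing axis to `m` and `b` works for any sample shape, including the empty shape of a single draw, while the unexpanded product fails just as the TensorFlow version does.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "X = np.arange(7)\n",
        "\n",
        "def y_loc(m, b):\n",
        "    # Hypothetical helper: a trailing length-1 axis on m and b broadcasts\n",
        "    # against X's shape (7,), so the result has shape sample_shape + (7,).\n",
        "    return m[..., np.newaxis] * X + b[..., np.newaxis]\n",
        "\n",
        "for sample_shape in [(), (5, 3), (2, 4, 6)]:\n",
        "    m = np.zeros(sample_shape)\n",
        "    b = np.zeros(sample_shape)\n",
        "    print(sample_shape, y_loc(m, b).shape)\n",
        "# () (7,)\n",
        "# (5, 3) (5, 3, 7)\n",
        "# (2, 4, 6) (2, 4, 6, 7)\n",
        "\n",
        "try:\n",
        "    np.zeros((5, 3)) * X  # (5, 3) vs (7,): trailing dimensions 3 and 7 clash\n",
        "except ValueError as e:\n",
        "    print('broadcast error:', e)\n",
        "```\n",
        "\n",
        "The expansion is harmless for a single draw as well, since a length-1 axis simply broadcasts to length 7."
      ]
    },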
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5AEBbcjVsXQR"
      },
      "source": [
        "Adding an axis to both `m` and `b`, we can define a new JDS that supports multiple samples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "9rJ9WCVQsW0S"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              " array([[-1.0641694 ,  0.88205844, -0.3132895 ],\n",
              "        [ 0.7708484 , -0.08183189, -0.543864  ],\n",
              "        [-0.46075284, -1.8269578 ,  0.30572248],\n",
              "        [-1.4730763 , -1.749881  , -0.18791775],\n",
              "        [-1.1432608 , -0.03570032, -0.47378683]], dtype=float32)>,\n",
              " <tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              " array([[-0.05995331,  0.4670131 ,  0.39853612],\n",
              "        [-0.50897926,  0.55372673, -0.44930768],\n",
              "        [ 2.12264   , -0.8941609 , -0.22456498],\n",
              "        [-0.28325766, -0.6039566 , -0.7982028 ],\n",
              "        [ 1.6194319 , -1.5981796 , -1.0267515 ]], dtype=float32)>,\n",
              " <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=\n",
              " array([[[ 1.23666501e+00, -2.72573185e+00, -1.06902647e+00,\n",
              "          -3.99592471e+00, -5.51451778e+00, -4.81725502e+00,\n",
              "          -5.56694984e+00],\n",
              "         [ 4.49036419e-01,  8.42256904e-01,  1.65697992e+00,\n",
              "           2.83218813e+00,  2.02821064e+00,  5.30640173e+00,\n",
              "           7.88480282e+00],\n",
              "         [ 1.04598761e-01, -6.91915929e-01, -1.45380819e+00,\n",
              "          -2.99107218e+00, -7.48243392e-01, -2.39095449e-01,\n",
              "          -2.95680499e+00]],\n",
              " \n",
              "        [[ 1.42532825e+00, -3.29503775e-01,  2.74825788e+00,\n",
              "           4.71045971e-01,  2.95442867e+00,  5.41281986e+00,\n",
              "           4.12423992e+00],\n",
              "         [ 1.13851607e+00,  1.34247184e+00,  5.38553715e-01,\n",
              "           4.75679219e-01,  1.15889467e-01,  2.28273201e+00,\n",
              "           1.66366085e-01],\n",
              "         [-8.50983739e-01, -2.25449228e+00, -1.62029576e+00,\n",
              "          -2.47048473e+00, -8.28547478e-02, -1.62208068e+00,\n",
              "          -3.38254881e+00]],\n",
              " \n",
              "        [[ 2.75410676e+00,  1.73929715e+00,  1.65932381e+00,\n",
              "           1.43238759e+00,  7.23003149e-01, -4.07665223e-01,\n",
              "          -5.24324298e-01],\n",
              "         [-3.93893182e-01, -1.79903293e+00, -3.79906535e+00,\n",
              "          -4.41074371e+00, -9.76827240e+00, -9.46045876e+00,\n",
              "          -1.14899712e+01],\n",
              "         [-1.37748170e+00,  5.45929432e-01, -8.51358235e-01,\n",
              "           2.76324749e-02,  5.16971350e-01, -6.29880428e-01,\n",
              "           2.23690033e+00]],\n",
              " \n",
              "        [[ 2.06451297e+00, -2.04346943e+00, -3.22309828e+00,\n",
              "          -5.45961189e+00, -5.86767960e+00, -7.99706030e+00,\n",
              "          -8.01118088e+00],\n",
              "         [-1.71845675e+00, -2.55129766e+00, -2.98688173e+00,\n",
              "          -4.69979382e+00, -6.89284897e+00, -1.11423817e+01,\n",
              "          -1.29737835e+01],\n",
              "         [-1.13922238e-01, -1.64989650e-01, -1.72910857e+00,\n",
              "          -2.97116470e+00, -2.48031807e+00, -2.05811620e+00,\n",
              "          -1.51430011e+00]],\n",
              " \n",
              "        [[ 2.13675165e+00,  1.30672932e+00, -3.27593088e-03,\n",
              "          -1.38755083e+00, -1.46972406e+00, -3.88024116e+00,\n",
              "          -4.52536440e+00],\n",
              "         [-2.77965927e+00, -1.04991031e+00, -1.96163297e+00,\n",
              "          -1.44081473e+00, -6.46156311e-01, -3.07756782e+00,\n",
              "          -3.05591631e+00],\n",
              "         [-6.00465536e-01, -1.80835783e+00, -2.14595556e+00,\n",
              "          -2.22402120e+00, -2.26174808e+00, -3.47439361e+00,\n",
              "          -3.31842375e+00]]], dtype=float32)>]"
            ]
          },
          "execution_count": 21,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ia = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Independent(   # Y\n",
        "        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),\n",
        "        reinterpreted_batch_ndims=1)\n",
        "])\n",
        "\n",
        "ss = jds_ia.sample([5, 3])\n",
        "ss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "8fsYEy6Fla0o"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              "array([[-13.1261215, -13.386831 , -14.021704 ],\n",
              "       [-15.311695 , -11.299437 , -13.955756 ],\n",
              "       [-11.30703  , -14.554242 , -12.285254 ],\n",
              "       [-13.204155 , -15.515974 , -11.195538 ],\n",
              "       [-12.403549 , -12.712912 ,  -9.4606905]], dtype=float32)>"
            ]
          },
          "execution_count": 22,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ia.log_prob(ss)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6ArLyKqJtY3Z"
      },
      "source": [
        "As an extra check, we'll verify that the log probability `jds_ia` assigns to a single batch point matches the value `jds_i` gives for that point on its own:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "9_2lIJyJtpyW"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(), dtype=float32, numpy=0.0>"
            ]
          },
          "execution_count": 23,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "(jds_ia.log_prob(ss)[3, 1] -\n",
        " jds_i.log_prob([ss[0][3, 1],\n",
        "                 ss[1][3, 1],\n",
        "                 ss[2][3, 1, :]]))"
      ]
    },
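    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nPz4kq8XbR1c"
      },
      "source": [
        "The same check can be reproduced outside TFP. The NumPy sketch below is an illustration, not how TFP computes anything internally. It writes out the joint log-density of this model, `log N(m; 0, 1) + log N(b; 0, 1) + sum_i log N(y_i; m * x_i + b, 1)`, evaluates it on a hand-drawn `[5, 3]` batch using the same `[..., np.newaxis]` broadcasting as `jds_ia`, and confirms that entry `[3, 1]` equals the log-density of that single draw. It assumes the covariates are `np.arange(7.)`, and the names `X_demo`, `normal_logpdf` and `joint_log_prob` are hypothetical helpers introduced only for this sketch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aQ1nZk7uVe2m"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def normal_logpdf(x, loc, scale):\n",
        "  # Elementwise log-density of Normal(loc, scale).\n",
        "  z = (x - loc) / scale\n",
        "  return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)\n",
        "\n",
        "X_demo = np.arange(7.)  # assumed 1-D covariates, shape [7]\n",
        "\n",
        "def joint_log_prob(m, b, y):\n",
        "  # Hypothetical helper: m and b have shape S, y has shape S + [7].\n",
        "  m, b = np.asarray(m), np.asarray(b)\n",
        "  loc = m[..., np.newaxis] * X_demo + b[..., np.newaxis]\n",
        "  # Summing over the last axis plays the role of\n",
        "  # tfd.Independent(..., reinterpreted_batch_ndims=1).\n",
        "  return (normal_logpdf(m, 0., 1.) + normal_logpdf(b, 0., 1.)\n",
        "          + normal_logpdf(y, loc, 1.).sum(axis=-1))\n",
        "\n",
        "# Draw a [5, 3] batch of (m, b, Y) by hand.\n",
        "rng = np.random.default_rng(0)\n",
        "m_np = rng.normal(size=(5, 3))\n",
        "b_np = rng.normal(size=(5, 3))\n",
        "Y_np = (m_np[..., np.newaxis] * X_demo + b_np[..., np.newaxis]\n",
        "        + rng.normal(size=(5, 3, 7)))\n",
        "\n",
        "batched = joint_log_prob(m_np, b_np, Y_np)                    # shape (5, 3)\n",
        "single = joint_log_prob(m_np[3, 1], b_np[3, 1], Y_np[3, 1])   # shape ()\n",
        "print(batched.shape, np.isclose(batched[3, 1], single))"
      ]
    },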
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_h7sJ2bkfOS7"
      },
      "source": [
        "<a id='AutoBatching-For-The-Win'></a>\n",
        "### AutoBatching For The Win\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J7nqIUMxuKzw"
      },
      "source": [
        "Excellent! We now have a version of JointDistribution that handles all our desiderata: `log_prob` returns a scalar thanks to the use of `tfd.Independent`, and multiple samples work now that we fixed broadcasting by adding extra axes.\n",
        "\n",
        "What if I told you there was an easier, better way? There is, and it's called `JointDistributionSequentialAutoBatched` (JDSAB):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "LZtVljb0fRx2"
      },
      "outputs": [],
      "source": [
        "jds_ab = tfd.JointDistributionSequentialAutoBatched([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
        "])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "gpvjnvXqu2Mk"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(), dtype=float32, numpy=-10.550432>"
            ]
          },
          "execution_count": 25,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ab.log_prob(jds.sample())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "id": "Js3luiUfns_R"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              "array([[-16.063435 , -11.415724 , -13.347199 ],\n",
              "       [-13.534442 , -20.753754 , -11.381274 ],\n",
              "       [-10.44528  , -12.624834 , -10.739721 ],\n",
              "       [-16.03442  , -13.358179 , -11.850428 ],\n",
              "       [ -9.4756365, -11.457652 , -10.145042 ]], dtype=float32)>"
            ]
          },
          "execution_count": 26,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "ss = jds_ab.sample([5, 3])\n",
        "jds_ab.log_prob(ss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "id": "v1ppa6F6bdkv"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              "array([[0., 0., 0.],\n",
              "       [0., 0., 0.],\n",
              "       [0., 0., 0.],\n",
              "       [0., 0., 0.],\n",
              "       [0., 0., 0.]], dtype=float32)>"
            ]
          },
          "execution_count": 27,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ab.log_prob(ss) - jds_ia.log_prob(ss)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xy-kuUbYwFB3"
      },
      "source": [
        "How does this work? While you could attempt to [read the code](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py#L426) for a deep understanding, we'll give a brief overview which is sufficient for most use cases:\n",
        "- Recall that our first problem was that our distribution for `Y` had `batch_shape=[7]` and `event_shape=[]`, and we used `Independent` to convert the batch dimension to an event dimension. JDSAB ignores the batch shapes of component distributions; instead it treats batch shape as an overall property of the model, which is assumed to be `[]` (unless specified otherwise by setting `batch_ndims > 0`). The effect is equivalent to using `tfd.Independent` to convert *all* batch dimensions of component distributions into event dimensions, as we did manually above.\n",
        "- Our second problem was the need to massage the shapes of `m` and `b` so that they could broadcast appropriately with `X` when creating multiple samples. With JDSAB, you write a model that generates a single sample, and JDSAB \"lifts\" the entire model to generate multiple samples using TensorFlow's [vectorized_map](https://www.tensorflow.org/api_docs/python/tf/vectorized_map). (This feature is analogous to JAX's [vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap).)"
      ]
    },
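    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vM2lTq9sYw3e"
      },
      "source": [
        "To get a feel for both ideas, here is a NumPy analogy; it is not TFP's implementation. We write the log-density for one draw, summing over every axis of `y` the way JDSAB folds all batch dimensions of a component into its event, and then lift that function over leading sample dimensions with `np.vectorize` and a gufunc-style `signature`. `np.vectorize` simply loops in Python, whereas `tf.vectorized_map` and JAX's `vmap` produce genuinely vectorized computations, but the contract is the same: you write the single-sample function and the leading dimensions are handled for you. The covariates `np.arange(7.)` and the helper names `X_demo`, `log_prob_one` and `log_prob_lifted` are assumptions made for this sketch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bR6tYp3sLw9h"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "X_demo = np.arange(7.)  # assumed 1-D covariates, shape [7]\n",
        "\n",
        "def normal_logpdf(x, loc, scale):\n",
        "  # Elementwise log-density of Normal(loc, scale).\n",
        "  z = (x - loc) / scale\n",
        "  return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)\n",
        "\n",
        "def log_prob_one(m, b, y):\n",
        "  # Log-density of ONE draw: scalar m and b, y shaped like X_demo.\n",
        "  # np.sum over every axis of y mimics JDSAB treating all batch\n",
        "  # dimensions of the Y component as event dimensions.\n",
        "  return (normal_logpdf(m, 0., 1.) + normal_logpdf(b, 0., 1.)\n",
        "          + np.sum(normal_logpdf(y, m * X_demo + b, 1.)))\n",
        "\n",
        "# Lift the single-draw function over any leading sample dimensions.\n",
        "# If X_demo were 2-D, log_prob_one would stay the same; only the core\n",
        "# signature would change, e.g. '(),(),(r,c)->()'.\n",
        "log_prob_lifted = np.vectorize(log_prob_one, signature='(),(),(n)->()')\n",
        "\n",
        "rng = np.random.default_rng(1)\n",
        "m_np = rng.normal(size=(5, 3))\n",
        "b_np = rng.normal(size=(5, 3))\n",
        "Y_np = rng.normal(size=(5, 3, 7))\n",
        "\n",
        "lifted = log_prob_lifted(m_np, b_np, Y_np)  # shape (5, 3)\n",
        "\n",
        "# The hand-broadcast computation that jds_ia needed.\n",
        "manual = (normal_logpdf(m_np, 0., 1.) + normal_logpdf(b_np, 0., 1.)\n",
        "          + normal_logpdf(Y_np, m_np[..., np.newaxis] * X_demo\n",
        "                          + b_np[..., np.newaxis], 1.).sum(axis=-1))\n",
        "print(lifted.shape, np.allclose(lifted, manual))"
      ]
    },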
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jUsWfVGqJiph"
      },
      "source": [
        "Exploring the batch shape issue in more detail, we can compare the batch shapes of our original \"bad\" joint distribution `jds`, our batch-fixed distributions `jds_i` and `jds_ia`, and our autobatched `jds_ab`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "id": "298I732fJDk5"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[TensorShape([]), TensorShape([]), TensorShape([7])]"
            ]
          },
          "execution_count": 28,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds.batch_shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "SBmdWrUuJGx0"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[TensorShape([]), TensorShape([]), TensorShape([])]"
            ]
          },
          "execution_count": 29,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_i.batch_shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "vD71eqN2JMhx"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[TensorShape([]), TensorShape([]), TensorShape([])]"
            ]
          },
          "execution_count": 30,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ia.batch_shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "qHmvRcxBJOAZ"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "TensorShape([])"
            ]
          },
          "execution_count": 31,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ab.batch_shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ozegq0diJuOL"
      },
      "source": [
        "We see that the original `jds` has subdistributions with different batch shapes. `jds_i` and `jds_ia` fix this by creating subdistributions with the same (empty) batch shape. `jds_ab` reports a single batch shape for the whole model, and that shape is empty."
      ]
    },
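    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eS7hKd1uQx5a"
      },
      "source": [
        "The dimension of size 7 has not disappeared from `jds_i` or `jds_ab`; it has moved from the batch shape into the event shape of `Y`. The sketch below checks this by printing `event_shape` next to `batch_shape`. To leave the notebook's own `X`, `jds_i` and `jds_ab` untouched, it rebuilds the model under the hypothetical names `X_demo`, `jds_i_demo` and `jds_ab_demo`, assuming 1-D covariates `np.arange(7)`; `jds_i_demo` is written in the same style as `jds_i` but may differ in detail from its original definition."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dG8mCx5qTj1z"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow_probability as tfp\n",
        "tfd = tfp.distributions\n",
        "\n",
        "X_demo = np.arange(7, dtype=np.float32)  # assumed 1-D covariates\n",
        "\n",
        "# Batch dimension of Y reinterpreted as an event dimension by hand.\n",
        "jds_i_demo = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Independent(   # Y\n",
        "        tfd.Normal(loc=m*X_demo + b, scale=1.),\n",
        "        reinterpreted_batch_ndims=1)\n",
        "])\n",
        "\n",
        "# Same model, with the reinterpretation done automatically.\n",
        "jds_ab_demo = tfd.JointDistributionSequentialAutoBatched([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Normal(loc=m*X_demo + b, scale=1.)  # Y\n",
        "])\n",
        "\n",
        "print(jds_i_demo.batch_shape, jds_i_demo.event_shape)\n",
        "print(jds_ab_demo.batch_shape, jds_ab_demo.event_shape)"
      ]
    },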
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bMm55xqV1dz6"
      },
      "source": [
        "It's worth noting that `JointDistributionSequentialAutoBatched` offers some additional generality for free. Suppose we make the covariates `X` (and, implicitly, the observations `Y`) two-dimensional:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "id": "1WfK-XbR1tXU"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[ 0,  1,  2,  3,  4,  5,  6],\n",
              "       [ 7,  8,  9, 10, 11, 12, 13]])"
            ]
          },
          "execution_count": 32,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "X = np.arange(14).reshape((2, 7))\n",
        "X"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VOnnkZooSj2C"
      },
      "source": [
        "Our `JointDistributionSequentialAutoBatched` works without any change to the model code (we only need to re-run the definition, because the shape of `X` is cached by `jds_ab.log_prob`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "id": "6WwMvoY71qph"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              " array([[-1.0845535 , -1.1255777 , -0.77237695],\n",
              "        [-1.2722294 ,  1.9274628 ,  0.75446165],\n",
              "        [ 1.214832  ,  2.03594   ,  0.68272597],\n",
              "        [-0.5651716 ,  1.6402307 ,  0.6128305 ],\n",
              "        [-0.01167952,  1.2298371 , -1.2706645 ]], dtype=float32)>,\n",
              " <tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              " array([[-0.5194242 ,  0.2823965 , -0.9434134 ],\n",
              "        [ 0.43568254, -0.37366644, -1.9174438 ],\n",
              "        [-0.8661425 , -1.4302185 ,  0.44063085],\n",
              "        [ 0.36433375, -0.38744366,  0.6491046 ],\n",
              "        [ 0.91218525,  0.36210015, -0.00910723]], dtype=float32)>,\n",
              " <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=\n",
              " array([[[[ -0.16874743,  -0.38901854,  -2.40703   ,  -5.930318  ,\n",
              "            -3.416317  ,  -7.0882726 ,  -6.631361  ],\n",
              "          [ -8.920654  , -10.499766  , -10.377804  , -12.9798355 ,\n",
              "           -11.721172  , -14.460028  , -14.922584  ]],\n",
              " \n",
              "         [[  0.50552297,  -0.9746385 ,  -2.047492  ,  -3.0749147 ,\n",
              "            -4.5619793 ,  -6.072114  ,  -5.1145515 ],\n",
              "          [ -7.2961216 ,  -8.094927  , -10.25211   , -12.26688   ,\n",
              "           -12.046576  , -15.34705   , -15.152906  ]],\n",
              " \n",
              "         [[ -0.8465157 ,  -2.6433449 ,  -0.76057017,  -3.1688592 ,\n",
              "            -4.687352  ,  -5.183547  ,  -5.0896225 ],\n",
              "          [ -6.222906  ,  -8.103443  ,  -7.795763  ,  -8.36684   ,\n",
              "           -10.562037  ,  -9.326081  ,  -9.593762  ]]],\n",
              " \n",
              " \n",
              "        [[[  1.054948  ,  -2.203673  ,  -3.035731  ,  -4.800442  ,\n",
              "            -5.2899976 ,  -5.9240775 ,  -6.730611  ],\n",
              "          [ -6.4754405 ,  -7.446973  , -10.764748  , -12.194825  ,\n",
              "           -11.556754  , -14.941436  , -14.943226  ]],\n",
              " \n",
              "         [[  0.87307787,   1.3859878 ,   2.6136284 ,   5.4836617 ,\n",
              "             5.8579865 ,  10.494877  ,  11.823118  ],\n",
              "          [ 11.510672  ,  14.746766  ,  16.719799  ,  18.618593  ,\n",
              "            21.580097  ,  22.609585  ,  25.759428  ]],\n",
              " \n",
              "         [[ -2.0380569 ,  -2.2008557 ,   0.43357986,   0.32134444,\n",
              "             0.36675143,   2.9957676 ,   1.6615164 ],\n",
              "          [  3.2243397 ,   3.220036  ,   4.315905  ,   6.7883563 ,\n",
              "             6.503477  ,   8.810654  ,   5.883856  ]]],\n",
              " \n",
              " \n",
              "        [[[ -0.477881  ,   1.4766507 ,   1.5208708 ,   3.147714  ,\n",
              "             2.9273605 ,   5.7710776 ,   7.128166  ],\n",
              "          [  7.3486524 ,   7.48754   ,   8.853534  ,  11.846103  ,\n",
              "            13.041363  ,  12.164903  ,  13.826527  ]],\n",
              " \n",
              "         [[ -2.935304  ,   0.5696763 ,   2.1498902 ,   6.319368  ,\n",
              "             7.923173  ,   8.151863  ,  11.570858  ],\n",
              "          [ 14.339904  ,  14.18277   ,  18.049622  ,  19.047941  ,\n",
              "            22.653297  ,  25.26222   ,  25.464987  ]],\n",
              " \n",
              "         [[  1.0329808 ,  -0.10444701,   0.99885136,   2.5327475 ,\n",
              "             2.0721416 ,   1.9450207 ,   4.6753073 ],\n",
              "          [  6.184873  ,   8.452423  ,   7.8260746 ,   7.713975  ,\n",
              "             7.0077796 ,  10.046227  ,  10.1453085 ]]],\n",
              " \n",
              " \n",
              "        [[[  0.3361371 ,  -0.62899804,   1.2562443 ,  -1.935529  ,\n",
              "            -1.4381697 ,  -1.5268946 ,  -3.8008852 ],\n",
              "          [ -4.1968484 ,  -6.028409  ,  -4.970623  ,  -4.9823346 ,\n",
              "            -5.6923776 ,  -6.535574  ,  -5.5532475 ]],\n",
              " \n",
              "         [[ -2.0243526 ,   3.3777661 ,   0.97641647,   4.6852875 ,\n",
              "             7.6430597 ,   5.8280125 ,   9.0458555 ],\n",
              "          [ 10.250172  ,  12.831018  ,  13.659218  ,  16.075794  ,\n",
              "            16.925209  ,  16.90435   ,  19.38226   ]],\n",
              " \n",
              "         [[  1.2758106 ,   0.83274007,   2.1775467 ,   3.1251085 ,\n",
              "             3.9337432 ,   2.543648  ,   5.1000204 ],\n",
              "          [  5.8442574 ,   6.0312934 ,   6.379141  ,   8.768039  ,\n",
              "             9.291983  ,   8.260785  ,   8.451964  ]]],\n",
              " \n",
              " \n",
              "        [[[  3.0444725 ,   0.73759735,   2.5216937 ,   0.04277879,\n",
              "             0.9555798 ,  -0.614954  ,   1.0725826 ],\n",
              "          [  3.0648081 ,   1.0510775 ,   0.9096012 ,   0.28714108,\n",
              "             1.4371622 ,   2.1362674 ,   1.9903467 ]],\n",
              " \n",
              "         [[  0.05708131,   1.2491966 ,   1.9845967 ,   3.4259818 ,\n",
              "             5.5484996 ,   7.8822956 ,   7.0572023 ],\n",
              "          [  9.535346  ,  11.390023  ,  10.360718  ,  12.881494  ,\n",
              "            11.301062  ,  13.86196   ,  16.829353  ]],\n",
              " \n",
              "         [[ -0.5573631 ,  -1.0938222 ,  -3.0080914 ,  -3.1928232 ,\n",
              "            -4.713949  ,  -7.016099  ,  -6.185412  ],\n",
              "          [ -8.42309   ,  -9.375599  , -10.624992  , -11.47895   ,\n",
              "           -14.62926   , -14.905938  , -18.084822  ]]]], dtype=float32)>]"
            ]
          },
          "execution_count": 33,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ab = tfd.JointDistributionSequentialAutoBatched([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
        "])\n",
        "\n",
        "ss = jds_ab.sample([5, 3])\n",
        "ss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "id": "GLvHMTpnSyvH"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
              "array([[-23.592081, -20.392092, -20.310911],\n",
              "       [-25.823744, -22.132751, -23.761002],\n",
              "       [-21.39077 , -27.747965, -25.098429],\n",
              "       [-21.14306 , -29.653296, -21.353765],\n",
              "       [-24.754295, -23.107279, -20.329145]], dtype=float32)>"
            ]
          },
          "execution_count": 34,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jds_ab.log_prob(ss)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AI40r2oETnVP"
      },
      "source": [
        "On the other hand, our carefully crafted `JointDistributionSequential` no longer works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "id": "tfYkdBIi0wJl"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]\n"
          ]
        }
      ],
      "source": [
        "jds_ia = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Independent(   # Y\n",
        "        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),\n",
        "        reinterpreted_batch_ndims=1)\n",
        "])\n",
        "\n",
        "try:\n",
        "  jds_ia.sample([5, 3])\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WLERQvFNTwQJ"
      },
      "source": [
        "To fix this, we'd have to add a second `tf.newaxis` to both `m` and `b` to match the shape of `X`, and increase `reinterpreted_batch_ndims` to 2 in the call to `Independent`. In this case, letting the auto-batching machinery handle the shape issues is shorter, easier, and more ergonomic."
      ]
    },
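    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fX3rLm6pNc8t"
      },
      "source": [
        "For comparison, here is a sketch of that manual fix. It keeps the notebook's `X`, `jds_ia` and `jds_ab` untouched by using the hypothetical names `X2`, `jds_ia2` and `jds_ab2`; `X2` holds the same values as the 2-D `X` above, cast to `float32`. The last line checks that the hand-fixed model and the auto-batched one agree, mirroring the earlier `jds_ab.log_prob(ss) - jds_ia.log_prob(ss)` comparison."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hK4vNs2eWb7y"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_probability as tfp\n",
        "tfd = tfp.distributions\n",
        "\n",
        "# Local float32 copy of the 2-D covariates used above.\n",
        "X2 = np.arange(14, dtype=np.float32).reshape((2, 7))\n",
        "\n",
        "jds_ia2 = tfd.JointDistributionSequential([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Independent(   # Y\n",
        "        # Two trailing axes so m and b broadcast against X2's [2, 7].\n",
        "        tfd.Normal(loc=m[..., tf.newaxis, tf.newaxis]*X2\n",
        "                       + b[..., tf.newaxis, tf.newaxis],\n",
        "                   scale=1.),\n",
        "        # Both axes of Y become event dimensions.\n",
        "        reinterpreted_batch_ndims=2)\n",
        "])\n",
        "\n",
        "jds_ab2 = tfd.JointDistributionSequentialAutoBatched([\n",
        "    tfd.Normal(loc=0., scale=1.),   # m\n",
        "    tfd.Normal(loc=0., scale=1.),   # b\n",
        "    lambda b, m: tfd.Normal(loc=m*X2 + b, scale=1.)  # Y\n",
        "])\n",
        "\n",
        "ss2 = jds_ia2.sample([5, 3])\n",
        "lp2 = jds_ia2.log_prob(ss2)\n",
        "print(ss2[2].shape, lp2.shape)\n",
        "# The manual fix and the auto-batched model should agree.\n",
        "print(np.allclose(lp2, jds_ab2.log_prob(ss2), atol=1e-4))"
      ]
    },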
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HIgCF6yJXpHE"
      },
      "source": [
        "Once again, we note that while this notebook explored `JointDistributionSequentialAutoBatched`, the other variants of `JointDistribution` have `AutoBatched` equivalents: `JointDistributionNamedAutoBatched` and `JointDistributionCoroutineAutoBatched`. (For users of `JointDistributionCoroutine`, `JointDistributionCoroutineAutoBatched` has the additional benefit that you no longer need to specify `Root` nodes; if you've never used `JointDistributionCoroutine` you can safely ignore this statement.)"
      ]
    },
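    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cJ9wBv2oHd4k"
      },
      "source": [
        "As an illustration of the `Root` point, here is the same regression written as a coroutine model. It is a sketch assuming a TFP version that provides `tfd.JointDistributionCoroutineAutoBatched`; `X7` and `model_fn` are hypothetical names, and `X7 = np.arange(7)` stands in for the 1-D covariates. With the plain `tfd.JointDistributionCoroutine`, the first two `yield`s would have to be wrapped in `tfd.JointDistributionCoroutine.Root` because they do not depend on earlier draws; the auto-batched version accepts them as written."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tL5pFa9rUo3i"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow_probability as tfp\n",
        "tfd = tfp.distributions\n",
        "\n",
        "X7 = np.arange(7, dtype=np.float32)  # assumed 1-D covariates\n",
        "\n",
        "def model_fn():\n",
        "  # No Root wrappers are needed with the auto-batched coroutine variant.\n",
        "  m = yield tfd.Normal(loc=0., scale=1., name='m')\n",
        "  b = yield tfd.Normal(loc=0., scale=1., name='b')\n",
        "  yield tfd.Normal(loc=m*X7 + b, scale=1., name='Y')\n",
        "\n",
        "jdc_ab = tfd.JointDistributionCoroutineAutoBatched(model_fn)\n",
        "\n",
        "draws = jdc_ab.sample([5, 3])\n",
        "m_s, b_s, y_s = draws  # one tensor per yield, in order\n",
        "print(y_s.shape, jdc_ab.log_prob(draws).shape)"
      ]
    },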
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mHacIM0iUW09"
      },
      "source": [
        "### Concluding Thoughts"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kXAC7GDWUaaY"
      },
      "source": [
        "In this notebook, we introduced `JointDistributionSequentialAutoBatched` and worked through a simple example in detail. Hopefully you learned something about TFP shapes and about autobatching!"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
